package tree;

import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Computes the maximum depth of a binary tree: the number of nodes on the
 * longest path from the root down to a leaf.
 */
public class MaxDepth {
    /**
     * Builds the sample tree {@code [3, 9, 20, null, null, 15, 7]} and prints
     * its maximum depth.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        MaxDepth maxDepth = new MaxDepth();
        TreeNode treeNode = new TreeNode(3);
        treeNode.left = new TreeNode(9);
        treeNode.right = new TreeNode(20);
        treeNode.left.left = null;
        treeNode.left.right = null;
        treeNode.right.left = new TreeNode(15);
        treeNode.right.right = new TreeNode(7);
        int max = maxDepth.maxDepth(treeNode);
        System.out.println("max -> " + max);
    }

    /**
     * Returns the maximum depth of the tree rooted at {@code treeNode}.
     *
     * <p>The traversal is an iterative level-order (breadth-first) walk rather
     * than recursion, so its stack usage does not grow with tree height. A
     * heavily skewed tree therefore cannot trigger a {@link StackOverflowError}.
     *
     * @param treeNode root of the tree; may be {@code null}
     * @return the number of levels in the tree, or {@code 0} if
     *         {@code treeNode} is {@code null}
     */
    public int maxDepth(TreeNode treeNode) {
        if (treeNode == null) {
            return 0;
        }
        // ArrayDeque rejects nulls, so only non-null children are enqueued.
        Deque<TreeNode> queue = new ArrayDeque<>();
        queue.add(treeNode);
        int depth = 0;
        while (!queue.isEmpty()) {
            depth++;
            // Drain exactly the nodes of the current level; their children
            // form the next level.
            for (int remaining = queue.size(); remaining > 0; remaining--) {
                TreeNode node = queue.poll();
                if (node.left != null) {
                    queue.add(node.left);
                }
                if (node.right != null) {
                    queue.add(node.right);
                }
            }
        }
        return depth;
    }
}
